#lang racket


(require "coord.rkt"
         "math.rkt")

(provide
 matrix?
 
 m-lines
 m-cols
 m-frame
 m-cs
 m-n
 
 m-rotation
 m-scaling
 m-translation
 
 m-line
 m-column
 
 m-translation-c
 
 m-neg
 +m
 -m
 *m
 m*p
 
 m-transform
 
 m-zero
 m-identity
 
 list<-matrix
 vector-lines<-matrix
 vector-cols<-matrix)


;; Dot product of two numeric vectors of equal length.
;; The vectors are multiplied element-wise (vector-map raises an error when
;; the lengths differ) and the products are summed left to right from exact 0.
(define (v*v v1 v2)
  (define products (vector-map * v1 v2))
  (for/sum ([p (in-vector products)])
    p))


;; Custom writer for `matrix`: prints the 16 stored values as four rows of
;; four space-separated values, rendered with ~a. The first row is prefixed
;; by "m(", later rows are indented by two spaces, and a closing ")" follows
;; the last row. The print mode argument is ignored, so write, display and
;; print all produce the same text.
(define (matrix-write m port mode)
  (define vals (matrix-vals m))
  (for ([row (in-range 4)])
    (write-string (if (zero? row) "m(" "  ") port)
    (write-string
     (string-join
      (for/list ([k (in-range (* row 4) (* (+ row 1) 4))])
        (format "~a" (vector-ref vals k)))
      " ")
     port)
    ;; Rows are separated by newlines; the last row is followed directly by ")".
    (when (< row 3)
      (newline port)))
  (write-string ")" port))

;; A 4x4 matrix whose 16 entries are stored row by row in `vals`.
(struct matrix (vals)
  #:property prop:custom-write matrix-write)


;; Builds an affine matrix from three row vectors.
;; The first four elements of v1, v2 and v3 become rows 0, 1 and 2; each
;; vector must have at least four elements. Row 3 is always 0 0 0 1.
(define (m-lines v1 v2 v3)
  (define (row v) (vector-copy v 0 4))
  (matrix
   (vector-append (row v1)
                  (row v2)
                  (row v3)
                  (vector 0 0 0 1))))

;; Builds an affine matrix from its columns. Row 3 is always 0 0 0 1.
;; - (m-cols v1 v2 v3 v4): elements 0-2 of vi give x, y, z of column i;
;;   any further elements are ignored.
;; - (m-cols v): v holds the columns one after another in blocks of four;
;;   indices 0-2, 4-6, 8-10 and 12-14 give columns 0-3, and indices
;;   3, 7, 11 and 15 are ignored.
(define m-cols
  (case-lambda
    ((v1 v2 v3 v4)
     (define columns (list v1 v2 v3 v4))
     ;; Entry k of every column, in column order, forms matrix row k.
     (define (row k)
       (map (λ (col) (vector-ref col k)) columns))
     (matrix
      (list->vector
       (append (row 0) (row 1) (row 2) '(0 0 0 1)))))
    ((v)
     (apply m-cols
            (for/list ([start (in-range 0 16 4)])
              (vector-copy v start (+ start 3)))))))

;; Affine frame matrix with origin c.
;; Column 0 is (norm x), column 2 is (norm (cross-c x y)), and column 1 is
;; the normalized cross-c of those two, so y only needs to fix the plane,
;; not be orthogonal to x. Column 3 holds c.
(define (m-frame c x y)
  (define x-axis (norm x))
  (define z-axis (norm (cross-c x y)))
  (define y-axis (norm (cross-c z-axis x-axis)))
  (m-cols (vector-of-coord x-axis)
          (vector-of-coord y-axis)
          (vector-of-coord z-axis)
          (vector-of-coord c)))

;; Like m-frame, but x and y are given as positions. They are turned into
;; directions relative to the origin c (with sub-c) before building the frame.
(define (m-cs c x y)
  (let ([dir-x (sub-c x c)]
        [dir-y (sub-c y c)])
    (m-frame c dir-x dir-y)))

;; Affine frame matrix with origin c whose column 2 (z axis) is (norm n).
;; Column 0 is the normalized (collinear-cross-c z), and column 1 is the
;; normalized (cross-c z x).
(define (m-n c n)
  (let* ([z (norm n)]
         [x (norm (collinear-cross-c z))]
         [y (norm (cross-c z x))])
    (m-cols (vector-of-coord x)
            (vector-of-coord y)
            (vector-of-coord z)
            (vector-of-coord c))))

;; Rotation matrix for angle a about the axis n, using the axis-angle formula
;; R = t*(n n^T) + c*I + s*[n]x with c = cos a, s = sin a, t = 1 - c.
;; a is passed straight to cos/sin (radians); n is normalized first.
;; Returns m-identity when a is 0 or n equals u0.
(define (m-rotation a n)
  (define (axis-angle-matrix nx ny nz)
    (define c (cos a))
    (define s (sin a))
    (define t (- 1 c))
    (define unit (norm (xyz nx ny nz)))
    (define x (xyz-x unit))
    (define y (xyz-y unit))
    (define z (xyz-z unit))
    (matrix
     (vector (+ (* t (^2 x)) c)    (- (* t x y) (* s z))  (+ (* t x z) (* s y))  0
             (+ (* t x y) (* s z))  (+ (* t (^2 y)) c)    (- (* t y z) (* s x))  0
             ;; Row 2 uses t*x*z (an earlier version mistakenly had t*x*y).
             (- (* t x z) (* s y))  (+ (* t y z) (* s x))  (+ (* t (^2 z)) c)    0
             0                      0                      0                      1)))
  (cond
    [(or (= a 0) (eq-c n u0)) m-identity]
    [else (apply axis-angle-matrix (list-of-coord n))]))

;; Scaling matrix with factors x, y, z on the diagonal.
;; A single coordinate c may be passed instead; its xyz components are used.
(define m-scaling
  (case-lambda
    [(c) (m-scaling (xyz-x c) (xyz-y c) (xyz-z c))]
    [(x y z)
     (matrix (vector x 0 0 0
                     0 y 0 0
                     0 0 z 0
                     0 0 0 1))]))

;; Translation matrix whose column 3 holds the offset (x, y, z).
;; A single coordinate c may be passed instead; its xyz components are used.
(define m-translation
  (case-lambda
    [(c) (m-translation (xyz-x c) (xyz-y c) (xyz-z c))]
    [(x y z)
     (matrix (vector 1 0 0 x
                     0 1 0 y
                     0 0 1 z
                     0 0 0 1))]))

;; Row l of m, returned as a fresh 4-element vector.
(define (m-line m l)
  (define start (* 4 l))
  (vector-copy (matrix-vals m) start (+ start 4)))

;; Column c of m, returned as a fresh 4-element vector (rows 0 to 3).
(define (m-column m c)
  (define vals (matrix-vals m))
  (for/vector #:length 4 ([row (in-range 4)])
    (vector-ref vals (+ c (* 4 row)))))


;; Translation part of m (rows 0-2 of column 3) as an xyz coordinate.
(define (m-translation-c m)
  (define vals (matrix-vals m))
  (xyz (vector-ref vals 3)
       (vector-ref vals 7)
       (vector-ref vals 11)))


;; Element-wise negation of m, returned as a new matrix.
(define (m-neg m)
  (matrix
   (for/vector ([e (in-vector (matrix-vals m))])
     (- e))))

;; Element-wise sum of m1 and m2, returned as a new matrix.
(define (+m m1 m2)
  (let ([a (matrix-vals m1)]
        [b (matrix-vals m2)])
    (matrix (vector-map (λ (x y) (+ x y)) a b))))

;; Element-wise difference m1 - m2, computed as m1 plus the negation of m2.
(define (-m m1 m2)
  (let ([negated (m-neg m2)])
    (+m m1 negated)))

;; Matrix product m1 * m2. Entry (i, j) is the dot product of row i of m1
;; with column j of m2; entries are produced row by row.
(define (*m m1 m2)
  (matrix
   (for*/vector ([i (in-range 4)]
                 [j (in-range 4)])
     (v*v (m-line m1 i) (m-column m2 j)))))

; TODO: m*p should not make callers choose between point (w = 1) and vector (w = 0)
; semantics through its optional third argument.
;; Multiplies rows 0-2 of m by the homogeneous column (x, y, z, w) built from p
;; and returns the result as an xyz coordinate (row 3 of m is not used).
;; By default w = 1, so column 3 (translation) contributes. Passing a true
;; third argument uses w = 0, which ignores translation.
(define (m*p m p (direction? #f))
  (define homogeneous
    (vector (xyz-x p) (xyz-y p) (xyz-z p) (if direction? 0 1)))
  (define (row-dot i)
    (v*v (m-line m i) homogeneous))
  (xyz (row-dot 0) (row-dot 1) (row-dot 2)))


;; Composite matrix (*m (*m S R) T), where S scales by the components of s,
;; R rotates by angle a about axis n, and T translates by the components of t.
;; NOTE(review): m*p multiplies the matrix by a column vector, so this product
;; applies the translation first and the scaling last. The resulting translation
;; column is therefore S*R*t rather than t. The more common TRS order is T*R*S;
;; confirm that this ordering is intentional.
(define (m-transform t a n s)
  (define scale (m-scaling (xyz-x s) (xyz-y s) (xyz-z s)))
  (define rotate (m-rotation a n))
  (define translate (m-translation (xyz-x t) (xyz-y t) (xyz-z t)))
  (*m (*m scale rotate) translate))


;; The 4x4 matrix whose 16 entries are all exact 0.
(define m-zero
  (matrix (make-vector 16 0)))

;; The 4x4 identity matrix. In row-major storage the diagonal entries sit at
;; indices 0, 5, 10 and 15, i.e. the multiples of 5.
(define m-identity
  (matrix
   (build-vector 16 (λ (k) (if (zero? (remainder k 5)) 1 0)))))


;; All 16 entries of m as a list, row by row.
(define (list<-matrix m)
  (for/list ([e (in-vector (matrix-vals m))])
    e))

;; A fresh mutable copy of m's entries, row by row.
(define (vector-lines<-matrix m)
  (define vals (matrix-vals m))
  (build-vector (vector-length vals)
                (λ (k) (vector-ref vals k))))

;; m's entries laid out column by column (the layout accepted by the
;; one-argument form of m-cols). Only rows 0-2 of m are read. Row 3 is written
;; as 0 0 0 1 regardless of m's actual bottom row.
(define (vector-cols<-matrix m)
  (define vals (matrix-vals m))
  (for*/vector #:length 16 ([col (in-range 4)]
                            [row (in-range 4)])
    (if (= row 3)
        (if (= col 3) 1 0)
        (vector-ref vals (+ (* 4 row) col)))))